import os

# Force synchronous CUDA kernel launches so errors are reported at the failing
# call (a debugging aid). Presumably set here, before the project/torch imports
# below, so it is in effect before CUDA is initialised — NOTE(review): this
# slows GPU training; consider making it opt-in.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import datetime
import json
import random
import sys

import numpy as np

# Make the project root importable: strip every '/interface' segment from the
# current working directory and append the result to sys.path. Presumably the
# script is launched from the repo's 'interface' directory — TODO confirm.
cwd = os.getcwd()
sys.path.append(cwd.replace('/interface', ''))
# Debug output: show the resulting import search path.
print(sys.path)
from player_ranking.player_evaluation_metric import PlayerRanking, run_risk_sensitive_player_evaluation
from generic.model_util import get_distrib_q_model_save_path
# from torch import optim
from agent import SportsAgent
from evaluate.evaluate_distrib_rl import generate_game_plot, contextualized_empirical_risk_measure, \
    visualize_exp_variance_by_location, visualize_uncertainty_by_location
from generic.data_util import read_args, load_config, load_event_data, HistoryScoreCache, ICEHOCKEY_ACTIONS, \
    divide_dataset_according2date


def _reached_multiple(episode_num, frequency):
    """Return True when ``episode_num`` has just crossed a multiple of ``frequency``.

    The remainder ``episode_num % frequency`` drops below the previous
    episode's remainder exactly when a multiple is crossed; this form also
    works for non-integer frequencies read from the config.
    """
    return episode_num % frequency < (episode_num - 1) % frequency


def _summarize_correlations(return_correlations):
    """Summarize per-alpha correlation results relative to the ``"inf"`` entry.

    Args:
        return_correlations: Mapping from alpha label (string) to a sequence of
            correlation coefficients. The ``"inf"`` entry is the baseline that
            every other entry is compared against element-wise.

    Returns:
        ``(improve_corrcoef_sum, corrcoef_avg)``: the sum, over all non-baseline
        entries, of the mean element-wise difference from the baseline; and the
        mean of the baseline correlations.

    Raises:
        KeyError: If the ``"inf"`` baseline entry is missing.
    """
    baseline_key = "{0}".format(float('inf'))
    corrcoef_values_origin = np.asarray(return_correlations[baseline_key])
    improve_all = [np.mean(np.asarray(corrcoef_values) - corrcoef_values_origin)
                   for alpha, corrcoef_values in return_correlations.items()
                   if alpha != baseline_key]
    return np.sum(improve_all), np.mean(corrcoef_values_origin)


def _train_with_log(args, config, log_file):
    """Run the training loop for ``train``; progress goes to ``log_file`` (stdout if None)."""
    debug_mode = bool(args.DEBUG_MODE)
    debug_msg = 'debug_' if debug_mode else ''
    sanity_check_msg = None

    if args.LEARN_MODE == 'location_ha':
        print('-' * 100, file=log_file, flush=True)
        print("*** Warning: Launching the sanity check. ***", file=log_file, flush=True)
        # Sanity-check input = action vocabulary size + 4 extra features
        # (presumably location and home/away, per the mode name — TODO confirm).
        config['general']['model']['input_dim'] = len(ICEHOCKEY_ACTIONS) + 4
        sanity_check_msg = 'sanity_check_location_ha_'
        debug_msg = sanity_check_msg + debug_msg
        print('-' * 100, file=log_file, flush=True)
    elif args.LEARN_MODE != 'normal':
        raise ValueError("Unknown learning mode {0}".format(args.LEARN_MODE))

    print(json.dumps(config, indent=4), file=log_file, flush=True)

    agent = SportsAgent(config=config, log_file=log_file)
    all_files = sorted(os.listdir(agent.train_data_path))
    training_files, _, _ = divide_dataset_according2date(all_data_files=all_files,
                                                         train_rate=agent.train_rate,
                                                         sports=agent.sports,
                                                         if_split=agent.apply_data_date_div)
    if not training_files:
        # Without any game the episode loop below would never advance and spin forever.
        raise ValueError("No training files found in {0}".format(agent.train_data_path))

    running_avg_dqn_loss = HistoryScoreCache(capacity=500)
    episode_num = 0

    if args.CHECK_POINT is not None:
        date_label = args.CHECK_POINT
    else:
        date_label = datetime.datetime.now().strftime('%b-%d-%Y-%H:%M')

    model_save_mother_dir = get_distrib_q_model_save_path(agent, date_label, debug_msg)
    # Debug runs save under the debug-tagged directory but resume from the
    # same path with ``debug_msg`` stripped out.
    load_from_path = model_save_mother_dir.replace(debug_msg, '') if debug_mode else model_save_mother_dir
    os.makedirs(model_save_mother_dir, exist_ok=True)

    # Best "improved correlation" seen so far; gates saving of saved_model_best.
    best_improve_correl = -float('inf')
    if args.CHECK_POINT is not None and os.path.isfile(load_from_path):
        # NOTE(review): load_from_path is a directory path (created just above
        # in non-debug runs), so isfile() may never hold — confirm what
        # load_pretrained_model expects.
        _, episode_num, _, _, best_improve_correl = \
            agent.load_pretrained_model(load_from=load_from_path, log_file=log_file)

    # Plot roughly every tenth of a report period; clamp to 1 so a
    # report_frequency below 10 does not cause a modulo by zero.
    plot_frequency = max(1, int(float(agent.report_frequency) / 10))
    # Always plot the same game (index 15) for comparability across episodes;
    # fall back to the last game when there are fewer than 16.
    plot_game = training_files[min(15, len(training_files) - 1)]
    # Defined up front so the report line works even before the first plot.
    mean_game_std = max_game_std = min_game_std = None

    # One episode = one pass over one game; max_episode is only checked after
    # a full sweep over training_files, so the last sweep may overshoot it.
    while episode_num <= agent.max_episode:
        for file_name in training_files:
            s_a_sequence, r_sequence = agent.load_sports_data(game_label=file_name,
                                                              sanity_check_msg=sanity_check_msg)
            pid_sequence = agent.load_player_id(game_label=file_name)
            build_transitions = agent.build_rnn_transitions if agent.apply_rnn else agent.build_transitions
            transition_all = build_transitions(s_a_data=s_a_sequence,
                                               r_data=r_sequence,
                                               pid_sequence=pid_sequence)

            # Sweep the game's transitions in mini-batches (always at least one batch).
            counter = 0
            while True:
                batch_data = agent.get_transition_batch(transition_all=transition_all, counter=counter)
                loss = agent.update_distrib_dqn_model(num_tau=agent.num_tau, batch=batch_data)
                running_avg_dqn_loss.push(loss.detach().cpu().item())
                counter += 1
                if counter * agent.batch_size >= len(transition_all):
                    break
            episode_num += 1

            if _reached_multiple(episode_num, agent.update_target_frequency):
                agent.update_target_net()

            if debug_mode or _reached_multiple(episode_num, plot_frequency):
                mean_game_std, max_game_std, min_game_std = \
                    generate_game_plot(agent, plot_game,
                                       date_label=date_label,
                                       episode_num=episode_num,
                                       sanity_check_msg=sanity_check_msg,
                                       debug_msg=debug_msg)
                print("Episode: {0}, Training Loss: {1}.\n\n".format(
                    episode_num,
                    running_avg_dqn_loss.get_avg(),
                ), file=log_file, flush=True)

            if debug_mode or _reached_multiple(episode_num, agent.report_frequency):
                return_correlations, _ = run_risk_sensitive_player_evaluation(agent=agent,
                                                                              model_save_path=model_save_mother_dir,
                                                                              iteration=episode_num,
                                                                              log_file=log_file,
                                                                              sanity_check_msg=sanity_check_msg,
                                                                              debug_mode=debug_mode,
                                                                              debug_msg=debug_msg, )
                improve_corrcoef_sum, corrcoef_avg = _summarize_correlations(return_correlations)

                contextualized_empirical_risk_measure(agent=agent,
                                                      model_label=debug_msg + agent.sports +
                                                                  model_save_mother_dir.split('model')[-1],
                                                      episode_num=episode_num,
                                                      sanity_check_msg=sanity_check_msg,
                                                      debug_mode=debug_mode,
                                                      sports=agent.sports,
                                                      target_action='all')
                # visualize_exp_variance_by_location / visualize_uncertainty_by_location
                # (imported above) can be called here for heatmaps; currently disabled.

                if improve_corrcoef_sum > best_improve_correl:
                    print("Saving the best mode with evaluation avg correl: {0} and improved correl: {1}".format(
                        corrcoef_avg, improve_corrcoef_sum), file=log_file, flush=True)
                    agent.save_model_to_path(save_to_path=model_save_mother_dir + '/saved_model_best',
                                             coorelation=corrcoef_avg,
                                             episode_no=episode_num,
                                             log_file=log_file)
                    # Track the same quantity the condition compares against
                    # (previously the average correlation was stored here by mistake).
                    best_improve_correl = improve_corrcoef_sum

                print("Episode: {0}, Training Loss: {1}, avg correl: {2} and improved correl: {3} "
                      "Example Games std: (mean:{4}, max:{5} and min:{6}).".format(
                    episode_num,
                    running_avg_dqn_loss.get_avg(),
                    corrcoef_avg,
                    improve_corrcoef_sum,
                    mean_game_std,
                    max_game_std,
                    min_game_std,), file=log_file, flush=True)

                agent.save_model_to_path(save_to_path=model_save_mother_dir + '/saved_model_{0}'.format(episode_num),
                                         coorelation=corrcoef_avg,
                                         episode_no=episode_num,
                                         log_file=log_file)


def train(args):
    """Train the distributional DQN agent, periodically evaluating and checkpointing it.

    Args:
        args: Parsed command-line arguments from ``read_args``. Read here:
            ``DEBUG_MODE``, ``LEARN_MODE`` (``'normal'`` or ``'location_ha'``)
            and ``CHECK_POINT`` (date label of a previous run, or ``None``);
            ``load_config`` may read more.

    Raises:
        ValueError: If ``args.LEARN_MODE`` is unknown or no training files exist.
    """
    config, _, log_file_path = load_config(args)
    # This call owns the log file; close it however training ends.
    log_file = open(log_file_path, 'w') if log_file_path is not None else None
    try:
        _train_with_log(args, config, log_file)
    finally:
        if log_file is not None:
            log_file.close()


def test(args):
    """Refuse to run test mode from this script.

    Evaluation is handled by a separate ``run_hockey_evaluate`` entry point.

    Args:
        args: Parsed command-line arguments (ignored).

    Raises:
        ValueError: Always.
    """
    message = "please use the run_hockey_evaluate for testing"
    raise ValueError(message)


if __name__ == "__main__":
    args = read_args()
    # TRAIN_FLAG is converted with int(): non-zero selects training, zero selects test mode.
    entry_point = train if int(args.TRAIN_FLAG) else test
    entry_point(args)
